{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Dash Image Enhancing with ESRGAN.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oSYBcTqRMel-",
        "colab_type": "text"
      },
      "source": [
        "To start this Jupyter Dash app, please run all the cells below. The app will appear inside the last cell."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TQfI1Q95RlSY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q dash-bootstrap-components jupyter-dash plotly"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sZlhBNfNOvMc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import base64\n",
        "from io import BytesIO\n",
        "\n",
        "import dash\n",
        "import dash_bootstrap_components as dbc\n",
        "import dash_core_components as dcc\n",
        "import dash_html_components as html\n",
        "from dash.dependencies import Input, Output, State\n",
        "import jupyter_dash\n",
        "from PIL import Image\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yiGpew_AFcny",
        "colab_type": "text"
      },
      "source": [
        "## Custom components"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7wSRcx41FXA5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def image_card(src, header=None):\n",
        "    \"\"\"Build a Bootstrap card that shows an image below an optional header.\n",
        "\n",
        "    `src` is passed unchanged to `html.Img` and `header` to `dbc.CardHeader`;\n",
        "    the image is styled to span the full width of the card body.\n",
        "    \"\"\"\n",
        "    picture = html.Img(src=src, style={\"width\": \"100%\"})\n",
        "    title_bar = dbc.CardHeader(header)\n",
        "    content = dbc.CardBody(picture)\n",
        "\n",
        "    return dbc.Card([title_bar, content])"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mCc7o9jaFPGx",
        "colab_type": "text"
      },
      "source": [
        "## Helper functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1GPFEbp9FYDk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def preprocess_b64(image_enc):\n",
        "    \"\"\"Decode a base64 image string into a float32 batch tensor.\n",
        "\n",
        "    Args:\n",
        "        image_enc: Base64-encoded image, either bare or as a data URI\n",
        "            (everything up to and including \"base64,\" is discarded).\n",
        "\n",
        "    Returns:\n",
        "        A float32 tensor of shape (1, height, width, 3).\n",
        "    \"\"\"\n",
        "    decoded = base64.b64decode(image_enc.split(\"base64,\")[-1])\n",
        "\n",
        "    # channels=3 drops any alpha channel and expands grayscale to RGB, so the\n",
        "    # result always has three channels. expand_animations=False keeps GIFs\n",
        "    # 3-D (first frame only) instead of returning a 4-D stack of frames.\n",
        "    hr_image = tf.image.decode_image(\n",
        "        decoded, channels=3, expand_animations=False\n",
        "    )\n",
        "\n",
        "    return tf.expand_dims(tf.cast(hr_image, tf.float32), 0)\n",
        "\n",
        "\n",
        "def tf_to_b64(tensor, ext=\"jpeg\"):\n",
        "    \"\"\"Encode the first image of a batch tensor as a base64 data URI.\n",
        "\n",
        "    Args:\n",
        "        tensor: Batch of images; only `tensor[0]` is encoded. Values are\n",
        "            clipped to [0, 255] and cast to uint8 before encoding.\n",
        "        ext: Image format handed to PIL and used in the `image/<ext>` type.\n",
        "\n",
        "    Returns:\n",
        "        A string of the form `data:image/<ext>;base64,<payload>`.\n",
        "    \"\"\"\n",
        "    image = tf.cast(tf.clip_by_value(tensor[0], 0, 255), tf.uint8).numpy()\n",
        "\n",
        "    buffer = BytesIO()\n",
        "    Image.fromarray(image).save(buffer, format=ext)\n",
        "    encoded = base64.b64encode(buffer.getvalue()).decode(\"utf-8\")\n",
        "\n",
        "    # The payload must follow the comma directly; a data URI has no space there.\n",
        "    return f\"data:image/{ext};base64,{encoded}\""
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AQATs6UrO6Bt",
        "colab_type": "text"
      },
      "source": [
        "## App layout"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o1s1r5W_Fokl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "app = jupyter_dash.JupyterDash(external_stylesheets=[dbc.themes.BOOTSTRAP])\n",
        "\n",
        "# Dashed drop zone that the user clicks or drags an image onto.\n",
        "upload_prompt = dbc.Card(\n",
        "    \"Drag and Drop or Click\",\n",
        "    body=True,\n",
        "    style={\n",
        "        \"textAlign\": \"center\",\n",
        "        \"borderStyle\": \"dashed\",\n",
        "        \"borderColor\": \"black\",\n",
        "    },\n",
        ")\n",
        "\n",
        "controls = [\n",
        "    dcc.Upload(\n",
        "        upload_prompt,\n",
        "        id=\"img-upload\",\n",
        "        # Restrict uploads to images: the callback can only decode image data.\n",
        "        accept=\"image/*\",\n",
        "        multiple=False,\n",
        "    )\n",
        "]\n",
        "\n",
        "# One empty container per image card; the callback fills in their children.\n",
        "image_slots = [\n",
        "    dbc.Col(html.Div(id=img_id)) for img_id in [\"original-img\", \"enhanced-img\"]\n",
        "]\n",
        "\n",
        "app.layout = dbc.Container(\n",
        "    [\n",
        "        html.H1(\"Dash Image Enhancing with TensorFlow\"),\n",
        "        html.Hr(),\n",
        "        dbc.Row([dbc.Col(c) for c in controls]),\n",
        "        html.Br(),\n",
        "        dbc.Spinner(dbc.Row(image_slots)),\n",
        "    ],\n",
        "    fluid=True,\n",
        ")"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LcSpsNG3GqzO",
        "colab_type": "text"
      },
      "source": [
        "## Load ESRGAN model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h0Q5tQB_Gt7t",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# TensorFlow Hub handle of the ESRGAN model that the callback applies to uploads.\n",
        "ESRGAN_MODEL_URL = \"https://tfhub.dev/captain-pool/esrgan-tf2/1\"\n",
        "\n",
        "model = hub.load(ESRGAN_MODEL_URL)"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o0zNtwE0Fw4s",
        "colab_type": "text"
      },
      "source": [
        "## Dash Callbacks"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jCxKvhaiFzjy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@app.callback(\n",
        "    [Output(\"original-img\", \"children\"), Output(\"enhanced-img\", \"children\")],\n",
        "    [Input(\"img-upload\", \"contents\")],\n",
        "    [State(\"img-upload\", \"filename\")],\n",
        ")\n",
        "def enhance_image(img_str, filename):\n",
        "    \"\"\"Show the uploaded image next to its ESRGAN-enhanced version.\n",
        "\n",
        "    Args:\n",
        "        img_str: Base64-encoded contents from the upload component, or None\n",
        "            when nothing has been uploaded yet.\n",
        "        filename: Name of the uploaded file; currently unused.\n",
        "\n",
        "    Returns:\n",
        "        A pair of children for the \"original-img\" and \"enhanced-img\"\n",
        "        containers, in that order.\n",
        "    \"\"\"\n",
        "    if img_str is None:\n",
        "        return dash.no_update, dash.no_update\n",
        "\n",
        "    try:\n",
        "        low_res = preprocess_b64(img_str)\n",
        "    except (ValueError, tf.errors.InvalidArgumentError):\n",
        "        # Not valid base64 or not a decodable image: report it in the page\n",
        "        # instead of failing the callback, and clear any earlier result.\n",
        "        warning = dbc.Alert(\n",
        "            \"Could not read the uploaded file as an image.\", color=\"danger\"\n",
        "        )\n",
        "        return warning, None\n",
        "\n",
        "    # preprocess_b64 already returns float32, so no extra cast is needed here.\n",
        "    super_res = model(low_res)\n",
        "    sr_str = tf_to_b64(super_res)\n",
        "\n",
        "    lr = image_card(img_str, header=\"Original Image\")\n",
        "    sr = image_card(sr_str, header=\"Enhanced Image\")\n",
        "\n",
        "    return lr, sr"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DLdfwqxSGC0w",
        "colab_type": "text"
      },
      "source": [
        "## Run the app"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HxCUD5ZHF0qo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start the server and embed the running app in this cell's output.\n",
        "app.run_server(mode=\"inline\", height=700)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}